{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle as pkl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only real pickle files: test the extension instead of any name that merely contains '.pkl'.\n",
    "# Sorted so the load order does not depend on the order of the directory listing.\n",
    "pkls = sorted(f for f in os.listdir('./') if f.endswith('.pkl'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sparsity string (parsed from the file name) -> {layer key -> numpy weight array}\n",
    "sparse_weights = {}\n",
    "for p in pkls:\n",
    "    # NOTE(security): pickle.load can execute arbitrary code -- only load files you trust.\n",
    "    with open(p, 'rb') as f:\n",
    "        # .cpu() is a no-op for CPU tensors and lets tensors stored on a GPU be converted too.\n",
    "        weight = pkl.load(f).detach().cpu().numpy()\n",
    "\n",
    "    # The name must split into exactly four '_'-separated fields; fail naming the file if not.\n",
    "    parts = p.split('_')\n",
    "    if len(parts) != 4:\n",
    "        raise ValueError(\"unexpected weight file name: {!r}\".format(p))\n",
    "    _, topbot, m, sparsity = parts\n",
    "    sparsity = sparsity.split('.pkl')[0]\n",
    "\n",
    "    # First file seen for this sparsity: create the per-layer table (None = not loaded yet).\n",
    "    if sparsity not in sparse_weights:\n",
    "        sparse_weights[sparsity] = {\n",
    "            \"botl.0\": None,\n",
    "            \"botl.2\": None,\n",
    "            \"botl.4\": None,\n",
    "            \"topl.0\": None,\n",
    "            \"topl.2\": None,\n",
    "            \"topl.4\": None,\n",
    "            \"topl.6\": None,\n",
    "            \"topl.8\": None,\n",
    "        }\n",
    "\n",
    "    # The layer key is the concatenation of the 2nd and 3rd file-name fields.\n",
    "    sparse_weights[sparsity][topbot+m] = weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-picked sparsity levels to export. Each entry is used verbatim as a key into\n",
    "# sparse_weights, so it must equal the sparsity string parsed from a .pkl file name.\n",
    "sparsity_to_test = [\n",
    "    '0.5045721601363784',\n",
    "    '0.5541957834992964',\n",
    "    '0.599990952078147',\n",
    "    '0.6509844477216149',\n",
    "    '0.7006824838997727',\n",
    "    '0.7513115258686005',\n",
    "    '0.8001901754789479',\n",
    "    '0.8511912815239745',\n",
    "    '0.9004462232113866',\n",
    "    '0.9499996617599308'\n",
    "]\n",
    "\n",
    "# Make destination directory\n",
    "def create_dir(dest_dir):\n",
    "    \"\"\"Create ``dest_dir`` (and any missing parents); do nothing if it already exists.\n",
    "\n",
    "    Raises FileExistsError if ``dest_dir`` exists but is not a directory.\n",
    "    \"\"\"\n",
    "    # exist_ok=True already covers the \"already there\" case, so no os.stat probe\n",
    "    # (and no bare except hiding unrelated errors) is needed.\n",
    "    os.makedirs(dest_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dumped: sparse_weights/ts_95/sparsity_0.5/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.5/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.55/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.6/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.65/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.7/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.75/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.8/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.85/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.9/topl.8_h1_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/botl.0_h512_v13.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/botl.2_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/botl.4_h128_v256.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/topl.0_h1024_v479.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/topl.2_h1024_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/topl.4_h512_v1024.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/topl.6_h256_v512.txt\n",
      "dumped: sparse_weights/ts_95/sparsity_0.95/topl.8_h1_v256.txt\n"
     ]
    }
   ],
   "source": [
    "# Root output folder; the export loop below creates one sparsity_<x> sub-folder inside it.\n",
    "dest_dir = 'sparse_weights/ts_95/'\n",
    "\n",
    "def dump_to_txt(numpy_weights, filepath):\n",
    "    \"\"\"Write a 2-D array to ``filepath`` as space-separated text.\n",
    "\n",
    "    Line 1 holds the two dimensions; every following line holds one row.\n",
    "    A confirmation message is printed once the file has been closed.\n",
    "    \"\"\"\n",
    "    # Only matrices are accepted\n",
    "    assert len(numpy_weights.shape) == 2\n",
    "\n",
    "    n_rows, n_cols = numpy_weights.shape\n",
    "    with open(filepath, \"w\") as out:\n",
    "        out.write(\"{} {}\\n\".format(n_rows, n_cols))\n",
    "        for row in numpy_weights:\n",
    "            out.write(\" \".join(map(str, row)) + \"\\n\")\n",
    "    print(\"dumped: {}\".format(filepath))\n",
    "\n",
    "# Export every layer of each selected sparsity level into its own folder.\n",
    "for sp in sparsity_to_test:\n",
    "    # The folder name uses the sparsity rounded to two decimals.\n",
    "    folder = \"sparsity_{}\".format(round(float(sp), 2))\n",
    "    dest_path = os.path.join(dest_dir, folder)\n",
    "    create_dir(dest_path)\n",
    "\n",
    "    weights = sparse_weights[sp]\n",
    "    for key in weights:\n",
    "        val = weights[key]\n",
    "        h, v = val.shape\n",
    "        # The file name records the layer key and both dimensions of the matrix.\n",
    "        file_name = os.path.join(dest_path, \"{}_h{}_v{}.txt\".format(key, h, v))\n",
    "        dump_to_txt(val, file_name)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ww = sparse_weights['0.5541957834992964']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ww.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Sparsity levels that were loaded (the bare name `sparse` was undefined and raised NameError).\n",
    "sparse_weights.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "round(0.432, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Running totals over all layers: element count and count of exactly-zero weights.\n",
    "total_el = 0\n",
    "total_nz = 0\n",
    "\n",
    "for w, layer_weights in weights.items():\n",
    "    # Boolean mask, True where the weight is exactly zero.\n",
    "    sparsity_pattern = abs(layer_weights) == 0.\n",
    "    dims = sparsity_pattern.shape\n",
    "    n_element = dims[0] * dims[1]\n",
    "    nz = sparsity_pattern.sum()\n",
    "\n",
    "    total_el += n_element\n",
    "    total_nz += nz\n",
    "\n",
    "    # Three stacked panels: the mask itself, then the mask summed along axis 0 and along axis 1.\n",
    "    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(40, 20))\n",
    "    fig.suptitle(\"Weights for {}, sparsity: {}\".format(w, nz / n_element))\n",
    "    ax1.imshow(sparsity_pattern, cmap=\"gray\")\n",
    "    ax2.plot(sparsity_pattern.sum(axis=0))\n",
    "    ax3.plot(sparsity_pattern.sum(axis=1))\n",
    "\n",
    "print('total sparsity: {}'.format(total_nz/total_el))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Running totals over the plotted layers: element count and count of non-zero weights.\n",
    "total_el = 0\n",
    "total_nz = 0\n",
    "blocksize = 16  # edge length of the square tiles\n",
    "\n",
    "# The layer key gets its own name: the loop variable used to be `w`, which the shape\n",
    "# unpacking below overwrote, so the figure title showed the width instead of the layer.\n",
    "for name in weights.keys():\n",
    "    # Boolean mask, True where the weight is non-zero.\n",
    "    sparsity_pattern = abs(weights[name]) != 0.\n",
    "\n",
    "    h, w = sparsity_pattern.shape\n",
    "\n",
    "    # Small layers are skipped entirely.\n",
    "    if h >= 128 and w >= 128:\n",
    "        # Fraction of zero entries in each full tile, visited row by row.\n",
    "        num_nz_per_block = []\n",
    "        for i in range(h // blocksize):\n",
    "            for j in range(w // blocksize):\n",
    "                # i indexes tile rows and j tile columns (they used to be swapped, which\n",
    "                # dropped tiles of non-square matrices); floor division keeps every tile full.\n",
    "                _block = sparsity_pattern[i*blocksize:(i+1)*blocksize, j*blocksize:(j+1)*blocksize]\n",
    "                num_nz_per_block.append(1 - _block.sum() / (blocksize * blocksize))\n",
    "\n",
    "        n_element = h * w\n",
    "        nz = sparsity_pattern.sum()\n",
    "\n",
    "        total_el += n_element\n",
    "        total_nz += nz\n",
    "        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 20))\n",
    "        # The mask marks non-zeros, so sparsity is one minus the non-zero fraction.\n",
    "        fig.suptitle(\"Weights for {}, sparsity: {}\".format(name, 1 - nz / n_element))\n",
    "        ax1.imshow(sparsity_pattern, cmap=\"gray\")\n",
    "        ax2.plot(num_nz_per_block)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsity_pattern.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsity_pattern[0:16, 0:16].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(256//16):\n",
    "    print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dump_to_txt(numpy_weights, path):\n",
    "    \"\"\"Write a 2-D array to ``path``, one row per line, values separated by spaces.\n",
    "\n",
    "    NOTE(review): this redefines the dump_to_txt declared earlier in the notebook;\n",
    "    unlike that version it writes no leading shape line and prints nothing.\n",
    "    \"\"\"\n",
    "    # Should be 2-dim array\n",
    "    assert len(numpy_weights.shape) == 2\n",
    "\n",
    "    # Honour the `path` argument (it used to be ignored in favour of a fixed \"sp_weight.txt\").\n",
    "    with open(path, \"w\") as f:\n",
    "        f.writelines(\n",
    "            \" \".join(str(i) for i in l) + \"\\n\" for l in numpy_weights\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path = './sparse_weights_0.95/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = weights['botl.0'] == 0.00"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of exactly-zero entries in botl.0 (`a.s` was an incomplete attribute access).\n",
    "a.sum()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
